import torch
import torch as th
import torch.nn as nn

from model.util import (
    conv_nd,
    linear,
    zero_module,
    timestep_embedding,
    exists
)
from model.attention import SpatialTransformer
from model.unet import (
    TimestepEmbedSequential, ResBlock, Downsample, AttentionBlock, UNetModel
)

import torch.nn.functional as F
from collections import OrderedDict


class ControlledUnetModel(UNetModel):
    """UNet whose middle block and skip connections receive ControlNet residuals."""

    def forward(self, x, timesteps=None, context=None, control=None, only_mid_control=False, **kwargs):
        """Run the UNet, injecting ``control`` residuals when provided.

        ``control`` is consumed from its end with ``pop()``, so the caller's list
        is mutated. The last entry is added in place to the middle-block output.
        Then, unless ``only_mid_control`` is truthy, each output block pops one
        more entry and adds it to the encoder skip tensor it concatenates.

        Returns ``self.out`` applied to the decoder output after casting it back
        to ``x.dtype``.
        """
        emb = self.time_embed(timestep_embedding(timesteps, self.model_channels, repeat_only=False))

        # Encoder: remember every input-block output for the skip connections.
        skips = []
        h = x.type(self.dtype)
        for block in self.input_blocks:
            h = block(h, emb, context)
            skips.append(h)
        h = self.middle_block(h, emb, context)

        inject_into_skips = control is not None and not only_mid_control
        if control is not None:
            h += control.pop()

        # Decoder: skips are consumed LIFO, matching the order control was produced.
        for block in self.output_blocks:
            skip = skips.pop()
            if inject_into_skips:
                skip = skip + control.pop()
            h = block(torch.cat([h, skip], dim=1), emb, context)

        return self.out(h.type(x.dtype))


class Gate(nn.Module):
    """Soft spatial gate that splits a per-pixel weight between two feature maps.

    A 1x1-conv MLP maps ``cat(f1, f2)`` (channel axis) to a single-channel map
    ``attn`` in (0, 1). ``f1`` is scaled by ``1 - attn`` and ``f2`` by ``attn``.

    Args:
        dim: channel count of ``cat(f1, f2)``, i.e. ``f1.shape[1] + f2.shape[1]``.
        sparse: if True, gate values at or below ``threshold`` are zeroed. The
            remainder is rescaled by ``1 / (1 - threshold)`` so a gate of 1
            still maps to 1.
        threshold: sparsification cut-off used when ``sparse`` is True. Must be
            in [0, 1). The default matches the previously hard-coded 0.1.

    Raises:
        ValueError: if ``threshold`` is outside [0, 1).
    """

    def __init__(self, dim=8, sparse=False, threshold=0.1):
        super(Gate, self).__init__()
        if not 0.0 <= threshold < 1.0:
            # threshold == 1 would divide by zero in the sparse rescale.
            raise ValueError(f"threshold must be in [0, 1), got {threshold}")
        self.dim = dim
        self.sparse = sparse
        self.threshold = threshold
        self.mlp = nn.Sequential(
            nn.Conv2d(dim, 4, 1, bias=False),
            nn.SiLU(),
            nn.Conv2d(4, 1, 1, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, f1, f2):
        """Return ``(f1 * (1 - attn), f2 * attn)``.

        ``attn`` is computed from both inputs and broadcast over their channels.
        """
        attn = self.mlp(torch.cat((f1, f2), dim=1))
        if self.sparse:
            # Shift-and-rescale keeps attn in [0, 1] while zeroing weak gates.
            attn = F.relu(attn - self.threshold) * (1 / (1 - self.threshold))
        return f1 * (1 - attn), f2 * attn
    
class Gated_Fusion(nn.Module):  # old version for test only
    """Legacy two-input gate: ``f1`` is weighted by ``attn`` and ``f2`` by ``1 - attn``.

    ``attn`` is a single-channel sigmoid map produced by a 1x1-conv MLP over the
    channel concatenation of both inputs. ``dim`` is that concatenated channel
    count.
    """

    def __init__(self, dim=8):
        super(Gated_Fusion, self).__init__()
        self.mlp = nn.Sequential(
            nn.Conv2d(dim, 4, 1, bias=False),
            nn.SiLU(),
            nn.Conv2d(4, 1, 1, bias=False),
        )
        # Attribute name is misspelled ("sgimoid"); kept as-is so existing
        # references to it keep working.
        self.sgimoid = nn.Sigmoid()

    def forward(self, f1, f2):
        """Return ``(f1 * attn, f2 * (1 - attn))``."""
        logits = self.mlp(torch.cat((f1, f2), dim=1))
        weight = self.sgimoid(logits)
        complement = 1 - weight
        return f1 * weight, f2 * complement
    
class SFTLayer(nn.Module):
    """Spatial feature transform: modulate lifted features with a condition map.

    ``feat`` (4 channels) is lifted to 32 channels by ``input_block_feat``.
    ``cond`` (16 channels) drives two 1x1-conv branches that each produce 32
    channels, a scale and a shift. The result is ``lifted * (scale + 1) + shift``.
    """

    def __init__(self):
        super(SFTLayer, self).__init__()
        self.SFT_scale_conv0 = nn.Conv2d(16, 16, 1)
        self.SFT_scale_conv1 = nn.Conv2d(16, 32, 1)
        self.SFT_shift_conv0 = nn.Conv2d(16, 16, 1)
        self.SFT_shift_conv1 = nn.Conv2d(16, 32, 1)
        lift_layers = [
            conv_nd(2, 4, 16, 3, padding=1),
            nn.SiLU(),
            conv_nd(2, 16, 32, 3, padding=1),
        ]
        self.input_block_feat = nn.Sequential(*lift_layers)

    @staticmethod
    def _affine_branch(first, second, cond):
        """Return ``second(leaky_relu(first(cond), 0.1))``."""
        return second(F.leaky_relu(first(cond), 0.1, inplace=True))

    def forward(self, feat, cond):
        """Lift ``feat`` and apply the scale/shift predicted from ``cond``."""
        lifted = self.input_block_feat(feat)
        gamma = self._affine_branch(self.SFT_scale_conv0, self.SFT_scale_conv1, cond)
        beta = self._affine_branch(self.SFT_shift_conv0, self.SFT_shift_conv1, cond)
        return lifted * (gamma + 1) + beta
    

class QuickGELU(nn.Module):
    """Sigmoid-based GELU approximation: ``x * sigmoid(1.702 * x)``."""

    def forward(self, x: torch.Tensor):
        """Gate each element of ``x`` by ``sigmoid(1.702 * x)``."""
        gate = torch.sigmoid(x * 1.702)
        return x * gate


class LayerNorm(nn.LayerNorm):
    """Subclass torch's LayerNorm to handle fp16.

    The normalization is computed in at least float32 precision. The input and
    the affine parameters are upcast to ``promote_types(x.dtype, float32)``, so
    fp16/bf16 become float32 and float64 stays float64. The result is then cast
    back to the input's dtype. This avoids reduced-precision mean/variance and
    dtype mismatches between an fp16 input and fp32 weights.
    """

    def forward(self, x: torch.Tensor):
        """Apply LayerNorm in >= float32 precision and preserve the input dtype."""
        orig_type = x.dtype
        compute_type = torch.promote_types(orig_type, torch.float32)
        # weight/bias are None when elementwise_affine (or bias) is disabled.
        weight = None if self.weight is None else self.weight.to(compute_type)
        bias = None if self.bias is None else self.bias.to(compute_type)
        ret = F.layer_norm(x.to(compute_type), self.normalized_shape, weight, bias, self.eps)
        return ret.type(orig_type)

class ResidualAttentionBlock(nn.Module):
    """A transformer-style block with self-attention and an MLP.

    Pre-norm residual layout: ``x = x + attn(ln_1(x))`` then
    ``x = x + mlp(ln_2(x))``. The MLP widens to ``2 * d_model`` with QuickGELU.

    ``nn.MultiheadAttention`` is constructed with its default
    ``batch_first=False``, so 3-D inputs are read as ``(seq_len, batch, d_model)``.

    Args:
        d_model: embedding width.
        n_head: number of attention heads (passed to ``nn.MultiheadAttention``).
        attn_mask: optional mask forwarded as ``attn_mask``. Floating-point
            (additive) masks are cast to the input dtype. Boolean masks, where
            True means "not allowed to attend", are kept boolean.
    """

    def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None):
        super().__init__()
        self.attn = nn.MultiheadAttention(d_model, n_head)
        self.ln_1 = LayerNorm(d_model)
        self.mlp = nn.Sequential(
            OrderedDict([("c_fc", nn.Linear(d_model, d_model * 2)), ("gelu", QuickGELU()),
                         ("c_proj", nn.Linear(d_model * 2, d_model))])
        )
        self.ln_2 = LayerNorm(d_model)
        # Plain attribute (not a registered buffer), so Module.to() does not
        # move it; attention() moves it to the input's device per call.
        self.attn_mask = attn_mask

    def attention(self, x: torch.Tensor):
        """Apply self-attention to ``x`` and return only the attended values."""
        mask = self.attn_mask
        if mask is not None:
            if mask.is_floating_point():
                mask = mask.to(dtype=x.dtype, device=x.device)
            else:
                # Casting a boolean mask to x.dtype would turn "blocked" (True)
                # into an additive +1.0 bias and silently disable the mask.
                mask = mask.to(device=x.device)
        return self.attn(x, x, x, need_weights=False, attn_mask=mask)[0]

    def forward(self, x: torch.Tensor):
        """Forward pass through the residual attention block."""
        x = x + self.attention(self.ln_1(x))
        x = x + self.mlp(self.ln_2(x))
        return x

class Align(nn.Module):  # the modulate part
    """Modulate a 4-channel hint latent against the 4-channel noisy latent ``z_t``.

    Both inputs are lifted to ``dim`` channels by 3x3 convs and concatenated
    channel-wise. They are flattened to ``(batch, height * width, 2 * dim)``,
    passed through two ``ResidualAttentionBlock``s (4 heads), and projected back
    to ``dim`` channels. That result is added to the lifted ``z_t`` features and
    mapped back to 4 channels by a 3x3 conv.

    NOTE(review): ``ResidualAttentionBlock`` wraps ``nn.MultiheadAttention``
    with its default ``batch_first=False``, which reads a 3-D input as
    ``(seq, batch, embed)``. The tokens built here are ``(batch, H*W, embed)``,
    so attention is computed across the *batch* axis at each spatial position,
    not across positions within a sample. With batch size 1 each token only
    attends to itself; with larger batches the outputs depend on the other
    samples in the batch. Transposing the tokens to ``(H*W, batch, embed)``
    around ``self.trans`` would attend over space, but changes the outputs of
    weights trained with the current layout. Confirm intent before changing.
    """

    def __init__(self, dim = 32):
        super(Align, self).__init__()
        depth, heads = 2, 4
        width = 2 * dim  # hint and z_t features are concatenated

        self.conv_in_hint = nn.Conv2d(4, dim, kernel_size=3, padding=1)
        self.conv_in_zt = nn.Conv2d(4, dim, kernel_size=3, padding=1)

        blocks = [ResidualAttentionBlock(width, heads) for _ in range(depth)]
        self.trans = nn.Sequential(*blocks)
        self.linear = nn.Linear(width, dim)

        self.conv_out = nn.Conv2d(dim, 4, kernel_size=3, padding=1)

    def forward(self, hint, z_t):
        """Return a 4-channel map: ``conv_out(feat(z_t) + alpha(hint, z_t))``."""
        hint_feat = self.conv_in_hint(hint)
        zt_feat = self.conv_in_zt(z_t)

        b, c, h, w = hint_feat.shape
        stacked = torch.cat((hint_feat, zt_feat), dim=1)
        tokens = stacked.view(b, 2 * c, h * w).transpose(1, 2)
        mixed = self.linear(self.trans(tokens))
        alpha = mixed.transpose(1, 2).view(b, c, h, w)

        return self.conv_out(zt_feat + alpha)

    

class ControlNet(nn.Module):
    """ControlNet encoder driven by two hint latents that are aligned and gated.

    Builds the same input-block / middle-block stack as the UNet encoder. In
    ``forward``, each of the two hints is first modulated against the noisy
    latent by its own ``Align`` module. The pair is then re-weighted by a
    ``Gate`` and both are concatenated onto the latent before the first
    convolution. Every input block, and the middle block, is followed by a 1x1
    conv wrapped in ``zero_module`` (presumably zero-initialized, as in the
    reference ControlNet). ``forward`` returns those conv outputs.

    The first convolution takes ``in_channels + hint_channels * 2`` channels.
    Each ``Align`` outputs 4 channels, so ``hint_channels`` must be 4 for the
    concatenation in ``forward`` to match.
    """

    def __init__(
        self,
        image_size, # 32
        in_channels,
        model_channels, # 320
        hint_channels,
        num_res_blocks, # 2
        attention_resolutions, # [4, 2, 1]
        dropout=0,
        channel_mult=(1, 2, 4, 8),
        conv_resample=True,
        dims=2,
        use_checkpoint=False,
        use_fp16=False,
        num_heads=-1,
        num_head_channels=-1, # 64
        num_heads_upsample=-1,
        use_scale_shift_norm=False,
        resblock_updown=False,
        use_new_attention_order=False,
        use_spatial_transformer=False,  # custom transformer support
        transformer_depth=1,  # custom transformer support
        context_dim=None,  # custom transformer support 1024
        n_embed=None,  # custom support for prediction of discrete ids into codebook of first stage vq model
        legacy=True, # false
        disable_self_attentions=None,
        num_attention_blocks=None,
        disable_middle_self_attn=False,
        use_linear_in_transformer=False,
    ):
        super().__init__()
        if use_spatial_transformer:
            assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...'

        if context_dim is not None:
            assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...'
            from omegaconf.listconfig import ListConfig
            # isinstance (rather than an exact type comparison) also accepts subclasses.
            if isinstance(context_dim, ListConfig):
                context_dim = list(context_dim)

        if num_heads_upsample == -1:
            num_heads_upsample = num_heads

        if num_heads == -1:
            assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set'

        if num_head_channels == -1:
            assert num_heads != -1, 'Either num_heads or num_head_channels has to be set'

        self.dims = dims
        # image_size is accepted for config compatibility but is not used here.
        self.in_channels = in_channels
        self.model_channels = model_channels
        if isinstance(num_res_blocks, int):
            self.num_res_blocks = len(channel_mult) * [num_res_blocks] # 4*[2] = [2,2,2,2]
        else:
            if len(num_res_blocks) != len(channel_mult):
                raise ValueError("provide num_res_blocks either as an int (globally constant) or "
                                 "as a list/tuple (per-level) with the same length as channel_mult")
            self.num_res_blocks = num_res_blocks
        if disable_self_attentions is not None:
            # should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not
            assert len(disable_self_attentions) == len(channel_mult)
        if num_attention_blocks is not None:
            assert len(num_attention_blocks) == len(self.num_res_blocks)
            assert all(map(lambda i: self.num_res_blocks[i] >= num_attention_blocks[i], range(len(num_attention_blocks))))
            print(f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. "
                  f"This option has LESS priority than attention_resolutions {attention_resolutions}, "
                  f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, "
                  f"attention will still not be set.")

        self.attention_resolutions = attention_resolutions
        self.dropout = dropout
        self.channel_mult = channel_mult
        self.conv_resample = conv_resample
        self.use_checkpoint = use_checkpoint
        self.dtype = th.float16 if use_fp16 else th.float32
        self.num_heads = num_heads
        self.num_head_channels = num_head_channels
        self.num_heads_upsample = num_heads_upsample
        self.predict_codebook_ids = n_embed is not None

        time_embed_dim = model_channels * 4
        self.time_embed = nn.Sequential(
            linear(model_channels, time_embed_dim),
            nn.SiLU(),
            linear(time_embed_dim, time_embed_dim),
        )

        # Hint pre-processing: one Align per hint, then a Gate that balances them.
        self.gate = Gate()
        self.align_1 = Align()
        self.align_2 = Align()

        # First conv sees the latent concatenated with both (gated) hints.
        self.input_blocks = nn.ModuleList(
            [
                TimestepEmbedSequential(
                    conv_nd(dims, in_channels + hint_channels * 2, model_channels, 3, padding=1)
                )
            ]
        )
        # One zero conv per input block, kept in lockstep with input_blocks.
        self.zero_convs = nn.ModuleList([self.make_zero_conv(model_channels)])

        self._feature_size = model_channels
        input_block_chans = [model_channels]
        ch = model_channels
        ds = 1  # current downsampling factor, matched against attention_resolutions
        for level, mult in enumerate(channel_mult): # e.g. channel_mult = (1, 2, 4, 4)
            for nr in range(self.num_res_blocks[level]):
                layers = [
                    ResBlock(
                        ch,
                        time_embed_dim,
                        dropout,
                        out_channels=mult * model_channels,
                        dims=dims,
                        use_checkpoint=use_checkpoint,
                        use_scale_shift_norm=use_scale_shift_norm,
                    )
                ]
                ch = mult * model_channels
                if ds in attention_resolutions: # e.g. [4, 2, 1]
                    if num_head_channels == -1:
                        dim_head = ch // num_heads
                    else:
                        num_heads = ch // num_head_channels
                        dim_head = num_head_channels
                    if legacy:
                        dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
                    if exists(disable_self_attentions):
                        disabled_sa = disable_self_attentions[level]
                    else:
                        disabled_sa = False

                    if not exists(num_attention_blocks) or nr < num_attention_blocks[level]:
                        layers.append(
                            AttentionBlock(
                                ch,
                                use_checkpoint=use_checkpoint,
                                num_heads=num_heads,
                                num_head_channels=dim_head,
                                use_new_attention_order=use_new_attention_order,
                            ) if not use_spatial_transformer else SpatialTransformer(
                                ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
                                disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
                                use_checkpoint=use_checkpoint
                            )
                        )
                self.input_blocks.append(TimestepEmbedSequential(*layers))
                self.zero_convs.append(self.make_zero_conv(ch))
                self._feature_size += ch
                input_block_chans.append(ch)
            if level != len(channel_mult) - 1:
                # Downsample between levels (not after the last one).
                out_ch = ch
                self.input_blocks.append(
                    TimestepEmbedSequential(
                        ResBlock(
                            ch,
                            time_embed_dim,
                            dropout,
                            out_channels=out_ch,
                            dims=dims,
                            use_checkpoint=use_checkpoint,
                            use_scale_shift_norm=use_scale_shift_norm,
                            down=True,
                        )
                        if resblock_updown
                        else Downsample(
                            ch, conv_resample, dims=dims, out_channels=out_ch
                        )
                    )
                )
                ch = out_ch
                input_block_chans.append(ch)
                self.zero_convs.append(self.make_zero_conv(ch))
                ds *= 2
                self._feature_size += ch

        if num_head_channels == -1:
            dim_head = ch // num_heads
        else:
            num_heads = ch // num_head_channels
            dim_head = num_head_channels
        if legacy:
            dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
        self.middle_block = TimestepEmbedSequential(
            ResBlock(
                ch,
                time_embed_dim,
                dropout,
                dims=dims,
                use_checkpoint=use_checkpoint,
                use_scale_shift_norm=use_scale_shift_norm,
            ),
            AttentionBlock(
                ch,
                use_checkpoint=use_checkpoint,
                num_heads=num_heads,
                num_head_channels=dim_head,
                use_new_attention_order=use_new_attention_order,
            ) if not use_spatial_transformer else SpatialTransformer(  # always uses a self-attn
                ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
                disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer,
                use_checkpoint=use_checkpoint
            ),
            ResBlock(
                ch,
                time_embed_dim,
                dropout,
                dims=dims,
                use_checkpoint=use_checkpoint,
                use_scale_shift_norm=use_scale_shift_norm,
            ),
        )
        self.middle_block_out = self.make_zero_conv(ch)
        self._feature_size += ch

    def make_zero_conv(self, channels):
        """Return a ``channels -> channels`` 1x1 conv wrapped in ``zero_module``."""
        return TimestepEmbedSequential(zero_module(conv_nd(self.dims, channels, channels, 1, padding=0)))

    def forward(self, x, hint, timesteps, context, **kwargs):  # z_t, c_img, t, c_txt
        """Compute the per-block control outputs.

        Args:
            x: noisy latent (4 channels, as required by ``Align``'s input convs).
            hint: indexable; items 0 and 1 are the two hint latents
                (c_lq and c_depth per the original comment).
            timesteps: timesteps fed to ``timestep_embedding``.
            context: cross-attention conditioning, or None when unused.

        Returns:
            list of zero-conv outputs: one per input block, then one for the
            middle block.
        """
        t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
        emb = self.time_embed(t_emb)

        hint_1, hint_2 = hint[0], hint[1]  # c_lq, c_depth

        # modulate: align each hint against the current noisy latent
        hint_1 = self.align_1(hint_1, x)
        hint_2 = self.align_2(hint_2, x)

        # gate: returns hint_1 * (1 - attn), hint_2 * attn
        hint_1, hint_2 = self.gate(hint_1, hint_2)

        # together: latent + both hints along the channel axis
        x = torch.cat((x, hint_1, hint_2), dim=1)

        outs = []

        h = x.type(self.dtype)
        emb = emb.type(self.dtype)
        # context may be None when no cross-attention is used; calling .type()
        # on it would raise AttributeError.
        if context is not None:
            context = context.type(self.dtype)
        for module, zero_conv in zip(self.input_blocks, self.zero_convs):
            h = module(h, emb, context)
            outs.append(zero_conv(h, emb, context))

        h = self.middle_block(h, emb, context)
        outs.append(self.middle_block_out(h, emb, context))

        return outs
